{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bcf54684",
   "metadata": {},
   "source": [
    "# 4-2,张量的数学运算\n",
    "\n",
    "张量数学运算主要有：标量运算，向量运算，矩阵运算，以及使用非常强大而灵活的爱因斯坦求和函数torch.einsum进行任意维的张量运算。\n",
    "\n",
    "此外我们还会介绍张量运算的广播机制。\n",
    "\n",
    "本篇文章内容如下：\n",
    "\n",
    "* 标量运算\n",
    "\n",
    "* 向量运算\n",
    "\n",
    "* 矩阵运算\n",
    "\n",
    "* 任意维张量运算(torch.einsum)\n",
    "\n",
    "* 广播机制\n",
    "\n",
    "本节中的torch.einsum的理解是重难点。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a6c5e1c2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.__version__=2.4.0\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "print(\"torch.__version__=\"+torch.__version__) \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "831e4e2c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "20ef9293",
   "metadata": {},
   "source": [
    "### 一，标量运算 (操作的张量至少是0维)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db35231b",
   "metadata": {},
   "source": [
    "张量的数学运算符可以分为标量运算符、向量运算符、以及矩阵运算符。\n",
    "\n",
    "加减乘除乘方，以及三角函数，指数，对数等常见函数，逻辑比较运算符等都是标量运算符。\n",
    "\n",
    "标量运算符的特点是对张量实施逐元素运算。\n",
    "\n",
    "有些标量运算符对常用的数学运算符进行了重载。并且支持类似numpy的广播特性。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "734173dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import numpy as np \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3ea5c924",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(3.)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.tensor(1.0)\n",
    "b = torch.tensor(2.0)\n",
    "a+b \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8bff7b5f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 6.,  8.],\n",
       "        [ 4., 12.]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.tensor([[1.0,2],[-3,4.0]])\n",
    "b = torch.tensor([[5.0,6],[7.0,8.0]])\n",
    "a+b  #运算符重载\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5b2a982e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ -4.,  -4.],\n",
       "        [-10.,  -4.]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a-b "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7b7edda7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  5.,  12.],\n",
       "        [-21.,  32.]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a*b "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "80742f80",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.2000,  0.3333],\n",
       "        [-0.4286,  0.5000]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a/b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f8aa2e1c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.,  4.],\n",
       "        [ 9., 16.]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a**2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d944cfe6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0000, 1.4142],\n",
       "        [   nan, 2.0000]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a**(0.5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "25b2c4cd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 2.],\n",
       "        [-0., 1.]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a%3 #求模\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2071e6c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.,  0.],\n",
       "        [-1.,  0.]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.div(a, b, rounding_mode='floor')  #地板除法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "96ee9ea3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[False,  True],\n",
       "        [False,  True]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a>=2 # torch.ge(a,2)  #ge: greater_equal缩写"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "568683e6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[False,  True],\n",
       "        [False, False]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(a>=2)&(a<=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "dbdb32cb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[True, True],\n",
       "        [True, True]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(a>=2)|(a<=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "889af586",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[False, False],\n",
       "        [False, False]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a==5 #torch.eq(a,5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d626b257",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0000, 1.4142],\n",
       "        [   nan, 2.0000]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.sqrt(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "3261aaae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([12., 21.])\n"
     ]
    }
   ],
   "source": [
    "a = torch.tensor([1.0,8.0])\n",
    "b = torch.tensor([5.0,6.0])\n",
    "c = torch.tensor([6.0,7.0])\n",
    "\n",
    "d = a+b+c\n",
    "print(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "925f1286",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([5., 8.])\n"
     ]
    }
   ],
   "source": [
    "print(torch.max(a,b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "c3ef7bd1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1., 6.])\n"
     ]
    }
   ],
   "source": [
    "print(torch.min(a,b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "013b21fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 3., -3.])\n",
      "tensor([ 2., -3.])\n",
      "tensor([ 3., -2.])\n",
      "tensor([ 2., -2.])\n"
     ]
    }
   ],
   "source": [
    "x = torch.tensor([2.6,-2.7])\n",
    "\n",
    "print(torch.round(x)) #保留整数部分，四舍五入\n",
    "print(torch.floor(x)) #保留整数部分，向下归整\n",
    "print(torch.ceil(x))  #保留整数部分，向上归整\n",
    "print(torch.trunc(x)) #保留整数部分，向0归整\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "2cf97ecf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0.6000, -0.7000])\n",
      "tensor([0.6000, 1.3000])\n"
     ]
    }
   ],
   "source": [
    "x = torch.tensor([2.6,-2.7])\n",
    "print(torch.fmod(x,2)) #作除法取余数 \n",
    "print(torch.remainder(x,2)) #作除法取剩余的部分，结果恒正\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "9567f956",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0.9000, -0.8000,  1.0000, -1.0000,  0.7000])\n",
      "tensor([  0.9000,  -0.8000,   1.0000, -20.0000,   0.7000])\n"
     ]
    }
   ],
   "source": [
    "# 幅值裁剪\n",
    "x = torch.tensor([0.9,-0.8,100.0,-20.0,0.7])\n",
    "y = torch.clamp(x,min=-1,max = 1)\n",
    "z = torch.clamp(x,max = 1)\n",
    "print(y)\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "a75c0a7f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(5.)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "relu = lambda x:x.clamp(min=0.0)\n",
    "relu(torch.tensor(5.0))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d5efc17",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4de3bfd8",
   "metadata": {},
   "source": [
    "### 二，向量运算 （原则上操作的张量至少是一维张量）"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25a6eac8",
   "metadata": {},
   "source": [
    "向量运算符只在一个特定轴上运算，将一个向量映射到一个标量或者另外一个向量。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "56781e21",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(45.)\n",
      "tensor(5.)\n",
      "tensor(9.)\n",
      "tensor(1.)\n",
      "tensor(362880.)\n",
      "tensor(2.7386)\n",
      "tensor(7.5000)\n",
      "tensor(5.)\n"
     ]
    }
   ],
   "source": [
    "#统计值\n",
    "\n",
    "a = torch.arange(1,10).float().view(3,3)\n",
    "\n",
    "print(torch.sum(a))\n",
    "print(torch.mean(a))\n",
    "print(torch.max(a))\n",
    "print(torch.min(a))\n",
    "print(torch.prod(a)) #累乘\n",
    "print(torch.std(a))  #标准差\n",
    "print(torch.var(a))  #方差\n",
    "print(torch.median(a)) #中位数\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e5d1cce1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 2., 3.],\n",
      "        [4., 5., 6.],\n",
      "        [7., 8., 9.]])\n",
      "torch.return_types.max(\n",
      "values=tensor([7., 8., 9.]),\n",
      "indices=tensor([2, 2, 2]))\n",
      "torch.return_types.max(\n",
      "values=tensor([3., 6., 9.]),\n",
      "indices=tensor([2, 2, 2]))\n"
     ]
    }
   ],
   "source": [
    "#指定维度计算统计值\n",
    "b = torch.arange(1,10).float().view(3,3)\n",
    "print(b)\n",
    "print(torch.max(b,dim = 0))\n",
    "print(torch.max(b,dim = 1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "b15460d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 1,  3,  6, 10, 15, 21, 28, 36, 45])\n",
      "tensor([     1,      2,      6,     24,    120,    720,   5040,  40320, 362880])\n",
      "tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])\n",
      "tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])\n",
      "torch.return_types.cummin(\n",
      "values=tensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),\n",
      "indices=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0]))\n"
     ]
    }
   ],
   "source": [
    "#cum扫描\n",
    "a = torch.arange(1,10)\n",
    "\n",
    "print(torch.cumsum(a,0))\n",
    "print(torch.cumprod(a,0))\n",
    "print(torch.cummax(a,0).values)\n",
    "print(torch.cummax(a,0).indices)\n",
    "print(torch.cummin(a,0))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "0075688b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.return_types.topk(\n",
      "values=tensor([[9., 7., 8.],\n",
      "        [5., 6., 4.]]),\n",
      "indices=tensor([[0, 0, 0],\n",
      "        [2, 2, 2]])) \n",
      "\n",
      "torch.return_types.topk(\n",
      "values=tensor([[9., 8.],\n",
      "        [3., 2.],\n",
      "        [6., 5.]]),\n",
      "indices=tensor([[0, 2],\n",
      "        [1, 2],\n",
      "        [1, 0]])) \n",
      "\n",
      "torch.return_types.sort(\n",
      "values=tensor([[7., 8., 9.],\n",
      "        [1., 2., 3.],\n",
      "        [4., 5., 6.]]),\n",
      "indices=tensor([[1, 2, 0],\n",
      "        [0, 2, 1],\n",
      "        [2, 0, 1]])) \n",
      "\n"
     ]
    }
   ],
   "source": [
    "#torch.sort和torch.topk可以对张量排序\n",
    "a = torch.tensor([[9,7,8],[1,3,2],[5,6,4]]).float()\n",
    "print(torch.topk(a,2,dim = 0),\"\\n\")\n",
    "print(torch.topk(a,2,dim = 1),\"\\n\")\n",
    "print(torch.sort(a,dim = 1),\"\\n\")\n",
    "\n",
    "#利用torch.topk可以在Pytorch中实现KNN算法\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "838ad3f6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "1758b43f",
   "metadata": {},
   "source": [
    "### 三，矩阵运算 （操作的张量至少是二维张量）"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "556953cb",
   "metadata": {},
   "source": [
    "矩阵必须是二维的。类似torch.tensor([1,2,3])这样的不是矩阵。\n",
    "\n",
    "矩阵运算包括：矩阵乘法，矩阵逆，矩阵求迹，矩阵范数，矩阵行列式，矩阵求特征值，矩阵分解等运算。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "a8336d39",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2, 4],\n",
      "        [6, 8]])\n"
     ]
    }
   ],
   "source": [
    "#矩阵乘法\n",
    "a = torch.tensor([[1,2],[3,4]])\n",
    "b = torch.tensor([[2,0],[0,2]])\n",
    "print(a@b)  #等价于torch.matmul(a,b) 或 torch.mm(a,b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "a0b62312",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 5, 4])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#高维张量的矩阵乘法在后面的维度上进行\n",
    "a = torch.randn(5,5,6)\n",
    "b = torch.randn(5,6,4)\n",
    "(a@b).shape \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "3b114762",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 3.],\n",
      "        [2., 4.]])\n"
     ]
    }
   ],
   "source": [
    "#矩阵转置\n",
    "a = torch.tensor([[1.0,2],[3,4]])\n",
    "print(a.t())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "d43f10f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-2.0000,  1.0000],\n",
      "        [ 1.5000, -0.5000]])\n"
     ]
    }
   ],
   "source": [
    "#矩阵逆，必须为浮点类型\n",
    "a = torch.tensor([[1.0,2],[3,4]])\n",
    "print(torch.inverse(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "0cfa03e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(5.)\n"
     ]
    }
   ],
   "source": [
    "#矩阵求trace\n",
    "a = torch.tensor([[1.0,2],[3,4]])\n",
    "print(torch.trace(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "ccf1a509",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(5.4772)\n"
     ]
    }
   ],
   "source": [
    "#矩阵求范数\n",
    "a = torch.tensor([[1.0,2],[3,4]])\n",
    "print(torch.norm(a))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "55f60815",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(-2.)\n"
     ]
    }
   ],
   "source": [
    "#矩阵行列式\n",
    "a = torch.tensor([[1.0,2],[3,4]])\n",
    "print(torch.det(a))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "6b7b5646",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.return_types.linalg_eig(\n",
      "eigenvalues=tensor([2.5000+2.7839j, 2.5000-2.7839j]),\n",
      "eigenvectors=tensor([[0.2535-0.4706j, 0.2535+0.4706j],\n",
      "        [0.8452+0.0000j, 0.8452-0.0000j]]))\n"
     ]
    }
   ],
   "source": [
    "#矩阵特征值和特征向量\n",
    "a = torch.tensor([[1.0,2],[-5,4]],dtype = torch.float)\n",
    "print(torch.linalg.eig(a))\n",
    "\n",
    "#两个特征值分别是 -2.5+2.7839j, 2.5-2.7839j \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "a30bcd31",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.3162, -0.9487],\n",
      "        [-0.9487,  0.3162]]) \n",
      "\n",
      "tensor([[-3.1623, -4.4272],\n",
      "        [ 0.0000, -0.6325]]) \n",
      "\n",
      "tensor([[1.0000, 2.0000],\n",
      "        [3.0000, 4.0000]])\n"
     ]
    }
   ],
   "source": [
    "#矩阵QR分解, 将一个方阵分解为一个正交矩阵q和上三角矩阵r\n",
    "#QR分解实际上是对矩阵a实施Schmidt正交化得到q\n",
    "\n",
    "a  = torch.tensor([[1.0,2.0],[3.0,4.0]])\n",
    "q,r = torch.linalg.qr(a)\n",
    "print(q,\"\\n\")\n",
    "print(r,\"\\n\")\n",
    "print(q@r)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b857857e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "72fee65c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.2298,  0.8835,  0.4082],\n",
      "        [-0.5247,  0.2408, -0.8165],\n",
      "        [-0.8196, -0.4019,  0.4082]]) \n",
      "\n",
      "tensor([9.5255, 0.5143]) \n",
      "\n",
      "tensor([[-0.6196, -0.7849],\n",
      "        [-0.7849,  0.6196]]) \n",
      "\n",
      "tensor([[1.0000, 2.0000],\n",
      "        [3.0000, 4.0000],\n",
      "        [5.0000, 6.0000]])\n"
     ]
    }
   ],
   "source": [
    "#矩阵svd分解\n",
    "#svd分解可以将任意一个矩阵分解为一个正交矩阵u,一个对角阵s和一个正交矩阵v.t()的乘积\n",
    "#svd常用于矩阵压缩和降维\n",
    "a=torch.tensor([[1.0,2.0],[3.0,4.0],[5.0,6.0]])\n",
    "\n",
    "u,s,v = torch.linalg.svd(a)\n",
    "\n",
    "print(u,\"\\n\")\n",
    "print(s,\"\\n\")\n",
    "print(v,\"\\n\")\n",
    "\n",
    "import torch.nn.functional as F \n",
    "print(u@F.pad(torch.diag(s),(0,0,0,1))@v.t())\n",
    "\n",
    "#利用svd分解可以在Pytorch中实现主成分分析降维\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ee33478",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b628faa8",
   "metadata": {},
   "source": [
    "### 四，任意维张量运算(torch.einsum) \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66ce0b69",
   "metadata": {},
   "source": [
    "如果问pytorch中最强大的一个数学函数是什么？\n",
    "\n",
    "我会说是torch.einsum：爱因斯坦求和函数。\n",
    "\n",
    "它几乎是一个\"万能函数\"：能实现超过一万种功能的函数。\n",
    "\n",
    "不仅如此，和其它pytorch中的函数一样，torch.einsum是支持求导和反向传播的，并且计算效率非常高。\n",
    "\n",
    "einsum 提供了一套既简洁又优雅的规则，可实现包括但不限于：内积，外积，矩阵乘法，转置和张量收缩（tensor contraction）等张量操作，熟练掌握 einsum 可以很方便的实现复杂的张量操作，而且不容易出错。\n",
    "\n",
    "尤其是在一些包括batch维度的高阶张量的相关计算中，若使用普通的矩阵乘法、求和、转置等算子来实现很容易出现维度匹配等问题，若换成einsum则会特别简单。\n",
    "\n",
    "套用一句深度学习paper标题当中非常时髦的话术，einsum is all you needed 😋！\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3f42904",
   "metadata": {},
   "source": [
    "#### 1，einsum规则原理"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e148abf0",
   "metadata": {},
   "source": [
    "顾名思义，einsum这个函数的思想起源于家喻户晓的小爱同学：爱因斯坦~。\n",
    "\n",
    "很久很久以前，小爱同学在捣鼓广义相对论。广义相对论表述各种物理量用的都是张量。\n",
    "\n",
    "比如描述时空有一个四维时空度规张量，描述电磁场有一个电磁张量，描述运动的有能量动量张量。\n",
    "\n",
    "在理论物理学家中，小爱同学的数学基础不算特别好，在捣鼓这些张量的时候，他遇到了一个比较头疼的问题：公式太长太复杂了。\n",
    "\n",
    "有没有什么办法让这些张量运算公式稍微显得对人类友好一些呢，能不能减少一些那种扭曲的$\\sum$求和符号呢？\n",
    "\n",
    "小爱发现，求和导致维度收缩，因此求和符号操作的指标总是只出现在公式的一边。\n",
    "\n",
    "例如在我们熟悉的矩阵乘法中\n",
    "\n",
    "$$C_{ij} = \\sum_{k} A_{ik} B_{kj}$$\n",
    "\n",
    "k这个下标被求和了，求和导致了这个维度的消失，所以它只出现在右边而不出现在左边。\n",
    "\n",
    "这种只出现在张量公式的一边的下标被称之为哑指标，反之为自由指标。\n",
    "\n",
    "小爱同学脑瓜子一转，反正这种只出现在一边的哑指标一定是被求和求掉的，干脆把对应的$\\sum$求和符号省略得了。\n",
    "\n",
    "这就是爱因斯坦求和约定：\n",
    "\n",
    " <font color=\"red\">**只出现在公式一边的指标叫做哑指标，针对哑指标的$\\sum$求和符号可以省略。** </font> \n",
    "\n",
    "公式立刻清爽了很多。\n",
    "\n",
    "$$C_{ij} =  A_{ik} B_{kj}$$\n",
    "\n",
    "这个公式表达的含义如下：\n",
    "\n",
    "C这个张量的第i行第j列由$A$这个张量的第i行第k列和$B$这个张量的第k行第j列相乘，这样得到的是一个三维张量$D$, 其元素为$D_{ikj}$，然后对$D$在维度k上求和得到。\n",
    "\n",
    "公式展现形式中除了省去了求和符号，还省去了乘法符号(代数通识)。\n",
    "\n",
    "\n",
    "借鉴爱因斯坦求和约定表达张量运算的清爽整洁，numpy、tensorflow和 torch等库中都引入了 einsum这个函数。\n",
    "\n",
    "上述矩阵乘法可以被einsum这个函数表述成\n",
    "\n",
    "\n",
    "```python\n",
    "C = torch.einsum(\"ik,kj->ij\",A,B)\n",
    "```\n",
    "\n",
    "这个函数的规则原理非常简洁，3句话说明白。\n",
    "\n",
    "<font color=\"red\">\n",
    "   \n",
    "* 1，用元素计算公式来表达张量运算。\n",
    "\n",
    "* 2，只出现在元素计算公式箭头左边的指标叫做哑指标。\n",
    "\n",
    "* 3，省略元素计算公式中对哑指标的求和符号。\n",
    "    \n",
    "</font> \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "42f733fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[19., 22.],\n",
      "        [43., 50.]])\n",
      "tensor([[19., 22.],\n",
      "        [43., 50.]])\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "\n",
    "A = torch.tensor([[1,2],[3,4.0]])\n",
    "B = torch.tensor([[5,6],[7,8.0]])\n",
    "\n",
    "C1 = A@B\n",
    "print(C1)\n",
    "\n",
    "C2 = torch.einsum(\"ik,kj->ij\",[A,B])\n",
    "print(C2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "829a512c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "571ac560",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "66da7f8c",
   "metadata": {},
   "source": [
    "#### 2，einsum基础范例"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2263f20c",
   "metadata": {},
   "source": [
    "einsum这个函数的精髓实际上是第一条:\n",
    "\n",
    "用元素计算公式来表达张量运算。\n",
    "\n",
    "而绝大部分张量运算都可以用元素计算公式很方便地来表达，这也是它为什么会那么神通广大。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "db60fd1f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before: torch.Size([3, 4, 5])\n",
      "after: torch.Size([3, 5, 4])\n"
     ]
    }
   ],
   "source": [
    "#例1，张量转置\n",
    "A = torch.randn(3,4,5)\n",
    "\n",
    "#B = torch.permute(A,[0,2,1])\n",
    "B = torch.einsum(\"ijk->ikj\",A) \n",
    "\n",
    "print(\"before:\",A.shape)\n",
    "print(\"after:\",B.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "6c89972a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before: torch.Size([5, 5])\n",
      "after: torch.Size([5])\n"
     ]
    }
   ],
   "source": [
    "#例2，取对角元\n",
    "A = torch.randn(5,5)\n",
    "#B = torch.diagonal(A)\n",
    "B = torch.einsum(\"ii->i\",A)\n",
    "print(\"before:\",A.shape)\n",
    "print(\"after:\",B.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "86ef3f2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before: torch.Size([4, 5])\n",
      "after: torch.Size([4])\n"
     ]
    }
   ],
   "source": [
    "#例3，求和降维\n",
    "A = torch.randn(4,5)\n",
    "#B = torch.sum(A,1)\n",
    "B = torch.einsum(\"ij->i\",A)\n",
    "print(\"before:\",A.shape)\n",
    "print(\"after:\",B.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80051a89",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "68f0cd2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before: torch.Size([5, 5]) torch.Size([5, 5])\n",
      "after: torch.Size([5, 5])\n"
     ]
    }
   ],
   "source": [
    "#例4，哈达玛积\n",
    "A = torch.randn(5,5)\n",
    "B = torch.randn(5,5)\n",
    "#C=A*B\n",
    "C = torch.einsum(\"ij,ij->ij\",A,B)\n",
    "print(\"before:\",A.shape, B.shape)\n",
    "print(\"after:\",C.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c759e5c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "873afb6d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before: torch.Size([10]) torch.Size([10])\n",
      "after: torch.Size([])\n"
     ]
    }
   ],
   "source": [
    "#例5，向量内积\n",
    "A = torch.randn(10)\n",
    "B = torch.randn(10)\n",
    "#C=torch.dot(A,B)\n",
    "C = torch.einsum(\"i,i->\",A,B)\n",
    "print(\"before:\",A.shape, B.shape)\n",
    "print(\"after:\",C.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66ce6fdd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "7f74e25a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before: torch.Size([10]) torch.Size([5])\n",
      "after: torch.Size([10, 5])\n"
     ]
    }
   ],
   "source": [
    "#例6，向量外积(类似笛卡尔积)\n",
    "A = torch.randn(10)\n",
    "B = torch.randn(5)\n",
    "#C = torch.outer(A,B)\n",
    "C = torch.einsum(\"i,j->ij\",A,B)\n",
    "print(\"before:\",A.shape, B.shape)\n",
    "print(\"after:\",C.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47729690",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "2e659e48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before: torch.Size([5, 4]) torch.Size([4, 6])\n",
      "after: torch.Size([5, 6])\n"
     ]
    }
   ],
   "source": [
    "#例7，矩阵乘法\n",
    "A = torch.randn(5,4)\n",
    "B = torch.randn(4,6)\n",
    "#C = torch.matmul(A,B)\n",
    "C = torch.einsum(\"ik,kj->ij\",A,B)\n",
    "print(\"before:\",A.shape, B.shape)\n",
    "print(\"after:\",C.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "486852a8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "4db4479d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before: torch.Size([3, 4, 5]) torch.Size([4, 3, 6])\n",
      "after: torch.Size([5, 6])\n"
     ]
    }
   ],
   "source": [
    "#例8，张量缩并\n",
    "A = torch.randn(3,4,5)\n",
    "B = torch.randn(4,3,6)\n",
    "#C = torch.tensordot(A,B,dims=[(0,1),(1,0)])\n",
    "C = torch.einsum(\"ijk,jih->kh\",A,B)\n",
    "print(\"before:\",A.shape, B.shape)\n",
    "print(\"after:\",C.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8379a983",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8fca2cf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "58b5ca1b",
   "metadata": {},
   "source": [
    "#### 3，einsum高级范例"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ceeb01c6",
   "metadata": {},
   "source": [
    "einsum可用于超过两个张量的计算。\n",
    "\n",
    "例如：双线性变换。这是向量内积的一种扩展，一种常用的注意力机制实现方式)\n",
    "\n",
    "不考虑batch维度时，双线性变换的公式如下: \n",
    "\n",
    "$$A = qWk^T $$  \n",
    "\n",
    "考虑batch维度时，无法用矩阵乘法表示，可以用元素计算公式表达如下：\n",
    "\n",
    "$$A_{ij} = \\sum_{k}\\sum_{l}Q_{ik}W_{jkl}K_{il} = Q_{ik}W_{jkl}K_{il}$$\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "63935d22",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a.shape: torch.Size([5])\n",
      "A.shape: torch.Size([8, 5])\n"
     ]
    }
   ],
   "source": [
    "#例9，bilinear注意力机制\n",
    "\n",
    "#====不考虑batch维度====\n",
    "q = torch.randn(10) #query_features\n",
    "k = torch.randn(10) #key_features\n",
    "W = torch.randn(5,10,10) #out_features,query_features,key_features\n",
    "b = torch.randn(5) #out_features\n",
    "\n",
    "#a = q@W@k.t()+b  \n",
    "a = torch.bilinear(q,k,W,b)\n",
    "print(\"a.shape:\",a.shape)\n",
    "\n",
    "\n",
    "#=====考虑batch维度====\n",
    "Q = torch.randn(8,10)    #batch_size,query_features\n",
    "K = torch.randn(8,10)    #batch_size,key_features\n",
    "W = torch.randn(5,10,10) #out_features,query_features,key_features\n",
    "b = torch.randn(5)       #out_features\n",
    "\n",
    "#A = torch.bilinear(Q,K,W,b)\n",
    "A = torch.einsum('bq,oqk,bk->bo',Q,W,K) + b\n",
    "print(\"A.shape:\",A.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6eabbda",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "453669da",
   "metadata": {},
   "source": [
    "我们也可以用einsum来实现更常见的scaled-dot-product 形式的 Attention.\n",
    "\n",
    "\n",
    "不考虑batch维度时，scaled-dot-product形式的Attention用矩阵乘法公式表示如下: \n",
    "\n",
    "$$ a = softmax(\\frac{q k^T}{d_k})$$\n",
    "\n",
    "考虑batch维度时，无法用矩阵乘法表示，可以用元素计算公式表达如下：\n",
    "\n",
    "$$ A_{ij} = softmax(\\frac{Q_{in}K_{ijn}}{d_k})$$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "6bc68830",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a.shape= torch.Size([6])\n",
      "A.shape= torch.Size([8, 6])\n"
     ]
    }
   ],
   "source": [
    "#例10，scaled-dot-product注意力机制\n",
    "\n",
    "#====不考虑batch维度====\n",
    "q = torch.randn(10)  #query_features\n",
    "k = torch.randn(6,10) #key_size, key_features\n",
    "\n",
    "d_k = k.shape[-1]\n",
    "a = torch.softmax(q@k.t()/d_k,-1) \n",
    "\n",
    "print(\"a.shape=\",a.shape )\n",
    "\n",
    "#====考虑batch维度====\n",
    "Q = torch.randn(8,10)  #batch_size,query_features\n",
    "K = torch.randn(8,6,10) #batch_size,key_size,key_features\n",
    "\n",
    "d_k = K.shape[-1]\n",
    "A = torch.softmax(torch.einsum(\"in,ijn->ij\",Q,K)/d_k,-1) \n",
    "\n",
    "print(\"A.shape=\",A.shape )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aecd84f4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "a6c0433f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#性能测试\n",
    "\n",
    "#=====考虑batch维度====\n",
    "Q = torch.randn(80,100)    #batch_size,query_features\n",
    "K = torch.randn(80,100)    #batch_size,key_features\n",
    "W = torch.randn(50,100,100) #out_features,query_features,key_features\n",
    "b = torch.randn(50)       #out_features\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "f3510af4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "738 μs ± 43.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit \n",
    "A = torch.bilinear(Q,K,W,b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "9b994c02",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "265 μs ± 1.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit \n",
    "A = torch.einsum('bq,oqk,bk->bo',Q,W,K) + b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b249ddd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "bf9b2e68",
   "metadata": {},
   "source": [
    "### 五，广播机制"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfd90e32",
   "metadata": {},
   "source": [
    "Pytorch的广播规则和numpy是一样的:\n",
    "\n",
    "* 1、如果张量的维度不同，将维度较小的张量进行扩展，直到两个张量的维度都一样。\n",
    "* 2、如果两个张量在某个维度上的长度是相同的，或者其中一个张量在该维度上的长度为1，那么我们就说这两个张量在该维度上是相容的。\n",
    "* 3、如果两个张量在所有维度上都是相容的，它们就能使用广播。\n",
    "* 4、广播之后，每个维度的长度将取两个张量在该维度长度的较大值。\n",
    "* 5、在任何一个维度上，如果一个张量的长度为1，另一个张量长度大于1，那么在该维度上，就好像是对第一个张量进行了复制。\n",
    "\n",
    "torch.broadcast_tensors可以将多个张量根据广播规则转换成相同的维度。\n",
    "\n",
    "维度扩展允许的操作有两种：\n",
    "1，增加一个维度\n",
    "2，对长度为1的维度进行复制扩展\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "24e8a6f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1, 2, 3],\n",
      "        [2, 3, 4],\n",
      "        [3, 4, 5]])\n"
     ]
    }
   ],
   "source": [
    "a = torch.tensor([1,2,3])\n",
    "b = torch.tensor([[0,0,0],[1,1,1],[2,2,2]])\n",
    "print(b + a) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "16d4b93c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2, 3],\n",
       "        [2, 3, 4],\n",
       "        [3, 4, 5]])"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cat([a[None,:]]*3,dim=0) + b  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "f76bb5df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1, 2, 3],\n",
      "        [1, 2, 3],\n",
      "        [1, 2, 3]]) \n",
      "\n",
      "tensor([[0, 0, 0],\n",
      "        [1, 1, 1],\n",
      "        [2, 2, 2]]) \n",
      "\n",
      "tensor([[1, 2, 3],\n",
      "        [2, 3, 4],\n",
      "        [3, 4, 5]])\n"
     ]
    }
   ],
   "source": [
    "a_broad,b_broad = torch.broadcast_tensors(a,b)\n",
    "print(a_broad,\"\\n\")\n",
    "print(b_broad,\"\\n\")\n",
    "print(a_broad + b_broad) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cb3114d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e466f833",
   "metadata": {},
   "source": [
    "**如果本书对你有所帮助，想鼓励一下作者，记得给本项目加一颗星星star⭐️，并分享给你的朋友们喔😊!** \n",
    "\n",
    "如果对本书内容理解上有需要进一步和作者交流的地方，欢迎在公众号\"算法美食屋\"下留言。作者时间和精力有限，会酌情予以回复。\n",
    "\n",
    "也可以在公众号后台回复关键字：**加群**，加入读者交流群和大家讨论。\n",
    "\n",
    "![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
